{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Perceptron\n",
    "\n",
    "> Theory: *Statistical Learning Methods* (《统计学习方法》), Chapter 2: Perceptron\n",
    "> \n",
    "> Code: a NumPy version and a PyTorch version\n",
    ">\n",
    "> Python 3.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
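    "### Model\n",
    "The perceptron is a linear model for binary classification. It looks for a hyperplane that **linearly separates** the instances into positive and negative examples, labeled $+1$ and $-1$. It is a **discriminative model**.\n",
    "\n",
    "The mapping from the input space to the output space is defined by\n",
    "$$\n",
    "f(x) = \\operatorname{sign}(w\\cdot x + b)\n",
    "$$\n",
    "where $w$ is the weight vector, $b$ is the bias, and $\\operatorname{sign}(z) = +1$ for $z \\ge 0$, $-1$ for $z < 0$.\n",
    "\n",
    "The hypothesis space of the perceptron is the **set** of all linear classifiers defined on the feature space. For linearly separable data, the separating hyperplane is generally not unique."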
   ]
  },
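  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The short sketch below illustrates the model. The helper `predict` is hypothetical and is not part of the original notebook. It applies $f(x) = \\operatorname{sign}(w\\cdot x + b)$ to every row of a NumPy array and follows the convention $\\operatorname{sign}(0) = +1$. The `sign` helper defined later in this notebook returns $-1$ at 0 instead. The two conventions differ only for points lying exactly on the hyperplane.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def predict(X, w, b):\n",
    "    # f(x) = sign(w . x + b) for each row of X; scores >= 0 map to +1, others to -1\n",
    "    scores = X @ w + b\n",
    "    return np.where(scores >= 0, 1, -1)\n",
    "\n",
    "# The three training points used below, with the hyperplane x1 + x2 - 3 = 0\n",
    "X = np.array([[3., 3.], [4., 3.], [1., 1.]])\n",
    "print(predict(X, np.array([1., 1.]), -3.0))  # [ 1  1 -1]\n",
    "```"
   ]
  },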
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
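    "### Strategy\n",
    "\n",
    "The training set is assumed to be linearly separable. The loss is defined on the misclassified points and is minimized by gradient descent.\n",
    "\n",
    "Every misclassified point satisfies $-y_i(w\\cdot x_i + b) > 0$. Multiplying this quantity by $\\frac{1}{\\|w\\|}$, where $\\|w\\|$ is the L2 norm of $w$, gives the distance from the misclassified point $x_i$ to the hyperplane: $\\frac{|w\\cdot x_i + b|}{\\|w\\|} = -\\frac{y_i(w\\cdot x_i + b)}{\\|w\\|}$.\n",
    "\n",
    "Summing these distances over all misclassified points and dropping the factor $\\frac{1}{\\|w\\|}$ gives the loss function of the perceptron\n",
    "$$\n",
    "\\mathrm{Loss}(w,b) = -\\sum_{x_i\\in M} y_i(w\\cdot x_i + b)\n",
    "$$\n",
    "where $M$ is the set of misclassified points.\n",
    "\n",
    "The loss is nonnegative, because every term $-y_i(w\\cdot x_i + b)$ in the sum is positive. It equals 0 when there are no misclassified points. The fewer misclassified points there are, and the closer they lie to the hyperplane, the smaller the loss.\n",
    "\n",
    "Moreover, for a fixed set $M$, the loss is a continuous, differentiable function of $w$ and $b$."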
   ]
  },
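  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a small numerical check of the loss and the point-to-hyperplane distance. The helper `perceptron_loss` is hypothetical. The parameters $w = (2, 2)$, $b = 0$ are chosen for illustration; they happen to be the parameters after the second update of the primal run below. With these parameters only $(1, 1)$ is misclassified. Its functional margin is $-4$, so the loss is 4 and its distance to the hyperplane is $4/\\|w\\| = 4/(2\\sqrt{2}) = \\sqrt{2}$.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def perceptron_loss(X, y, w, b):\n",
    "    # Functional margins y_i * (w . x_i + b); a point is misclassified when <= 0\n",
    "    margins = y * (X @ w + b)\n",
    "    mis = margins <= 0\n",
    "    # Loss(w, b) = -sum over M of y_i * (w . x_i + b)\n",
    "    loss = -margins[mis].sum()\n",
    "    # Distances of the misclassified points to the hyperplane (assumes w != 0)\n",
    "    dist = -margins[mis] / np.linalg.norm(w)\n",
    "    return loss, dist\n",
    "\n",
    "X = np.array([[3., 3.], [4., 3.], [1., 1.]])\n",
    "y = np.array([1, 1, -1])\n",
    "loss, dist = perceptron_loss(X, y, np.array([2., 2.]), 0.0)\n",
    "print(loss, dist)  # 4.0 [1.41421356]\n",
    "```"
   ]
  },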
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
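    "### Algorithm\n",
    "\n",
    "The loss is minimized with stochastic gradient descent (SGD).\n",
    "\n",
    "1. Primal form\n",
    "   \n",
    "   Solve $\\min_{w,b} \\mathrm{Loss}(w,b)$. At each step, one misclassified point is chosen at random and used to update $w$ and $b$.\n",
    "   The gradients of the loss with respect to $w$ and $b$ are\n",
    "   $$\n",
    "      \\nabla_w \\mathrm{Loss}(w,b) = -\\sum_{x_i\\in M} y_i x_i\n",
    "      \\\\\n",
    "      \\nabla_b \\mathrm{Loss}(w,b) = -\\sum_{x_i\\in M} y_i\n",
    "   $$\n",
    "   Given a learning rate $\\eta$ ($0 < \\eta \\le 1$), each misclassified point, i.e. a point with $y_i(w\\cdot x_i + b) \\le 0$, updates $w$ and $b$ as\n",
    "   $$\n",
    "      w \\leftarrow w + \\eta y_i x_i\n",
    "      \\\\\n",
    "      b \\leftarrow b + \\eta y_i\n",
    "   $$\n",
    "   These updates are repeated until no misclassified point remains. At that point the loss is 0 and $w, b$ define a separating hyperplane. For linearly separable data, the number of updates is finite (Novikoff's theorem).\n"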
   ]
  },
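  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gradient formulas and the update rule can be checked on the same data. This sketch uses two hypothetical helpers, `loss_grad` and `sgd_step`. One SGD step on a single misclassified point is $w - \\eta\\nabla_w$ of that point's loss, which is exactly $w + \\eta y_i x_i$. Starting from $w = (2, 2)$, $b = 0$ with $\\eta = 1$, the step on $(1, 1)$ gives $w = (1, 1)$, $b = -1$. This matches the third update printed by the primal run below.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def loss_grad(X, y, w, b):\n",
    "    # Gradients of Loss(w, b) over the misclassified set M:\n",
    "    #   grad_w = -sum_M y_i x_i,  grad_b = -sum_M y_i\n",
    "    mis = y * (X @ w + b) <= 0\n",
    "    grad_w = -(y[mis][:, None] * X[mis]).sum(axis=0)\n",
    "    grad_b = -y[mis].sum()\n",
    "    return grad_w, grad_b\n",
    "\n",
    "def sgd_step(x_i, y_i, w, b, eta=1.0):\n",
    "    # One SGD step on a single misclassified point: w + eta*y_i*x_i, b + eta*y_i\n",
    "    return w + eta * y_i * x_i, b + eta * y_i\n",
    "\n",
    "X = np.array([[3., 3.], [4., 3.], [1., 1.]])\n",
    "y = np.array([1, 1, -1])\n",
    "w, b = np.array([2., 2.]), 0.0\n",
    "print(loss_grad(X, y, w, b))  # only (1, 1) is misclassified here\n",
    "w, b = sgd_step(X[2], y[2], w, b)\n",
    "print(w, b)  # [1. 1.] -1.0\n",
    "```"
   ]
  },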
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from torch import nn\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train = np.array([[3.,3.],[4,3],[1,1]])\n",
    "y_train = np.array([1,1,-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sign(x):\n",
    "    # Returns +1 for x > 0 and -1 otherwise (note: sign(0) = -1 here)\n",
    "    if x > 0:\n",
    "        return 1\n",
    "    else:\n",
    "        return -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def draw_fig(x,y,w,b):\n",
    "    plt.figure(dpi=64,figsize=(4,4))\n",
    "\n",
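    "    # Scatter plot of the positive and negative instances, colored by label.\n",
    "    # The commented-out alternative combines boolean indexing with slicing.\n",
    "    # plt.scatter(x[y==1][:,0],x[y==1][:,1],color='r')\n",
    "    # plt.scatter(x[y==-1][:,0],x[y==-1][:,1],color='g')\n",
    "    plt.scatter(x[:,0],x[:,1],c=y)\n",
    "\n",
    "    # Draw the separating hyperplane w[0]*x1 + w[1]*x2 + b = 0,\n",
    "    # solved for x2 (assumes w[1] != 0)\n",
    "    x1 = np.arange(-1, 4, 0.1)\n",
    "    x2 = (w[0] * x1 + b) / (-w[1])\n",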
    "    plt.plot(x1,x2)\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1 % w is [3. 3.],b is 1\n",
      "epoch 2 % w is [2. 2.],b is 0\n",
      "epoch 3 % w is [1. 1.],b is -1\n",
      "epoch 4 % w is [0. 0.],b is -2\n",
      "epoch 5 % w is [3. 3.],b is -1\n",
      "epoch 6 % w is [2. 2.],b is -2\n",
      "epoch 7 % w is [1. 1.],b is -3\n",
      "\n",
      "finally get w is [1. 1.], b is -3\n"
     ]
    },
    {
     "data": {
      "text/plain": "<Figure size 256x256 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOYAAADhCAYAAADcb8kDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAAnYAAAJ2AHHoLmtAAAZxklEQVR4nO3de1zUBb7/8df3OzOAXEREEfCSJqamiaLAkKYs2q5RrVanEkZTFNj8VWuWW25uWT87PX57zu+37Wm7Ct4dsYvaTersnvxVp00uioiXvKSZJGqignJRmJnv+UO3zfLKDPP9zszn+R/gfL/vR+xrme8wMyiapmkIIQxF1XuAEOLnJEwhDEjCFMKAJEwhDEjCFMKAzN46Uf/+/enbt6+3TieET9i3bx+7d+/+2ee9Fmbfvn0pLi721umE8AmZmZkX/bzclRXCgCRMIQxIwhTCgCRMIQzI7TD37NmDxWKhpKTEE3uEEHggzAULFjBmzBhPbBFCnOdWmGVlZcTGxtKjRw+3h6wo+ZbqE01uH0cIf+BWmM8//zxz58695NftdjuZmZlkZmZSU1NzyX93ptWJveRbsgtLOFTX7M4kIfxCm8Ncv349I0aMIDo6+pL/xmazUVxcTHFxMfHx8Zf8dyEWE/bcVDpYTGQXlHC4XuIUga3NYVZWVvLpp58yfvx4/va3vzF79mwOHz7c5iHR4cHYc61YTCrZBaUcPXWmzccSwte1Ocx58+axYcMGPv74Y2699VZefPFF4uLi3BrTNSKYVXmpqApkFZTwvcQpApRHfo+5dOlSrFarJw5FTEQIRXlW0CC7sJRjp8965LhC+BJDPsEgpmMIq/KsOJwubIUl1DZInCKwGDJMgNjIEIryrZxpdTG5sJQTjS16TxLCawwbJkBcZAeK8q00nHVgKyzlpMQpAoShwwTo3qkDRXlWTjW3MnlRKXVNEqfwf4YPE6Bn51CK8qycaGxhyqIy6ptb9Z4kRLvyiTABekWfi/PY6bM8sKiUU2ckTuG/fCZMgN5dwijKt3K4/gxTF5dxWuIUfsqnwgTo0yWMVXlWqk80M21JOQ1nHXpPEsLjfC5MgISYcIryUvn2eCM5S8polDiFn/HJMAH6dYvAnmtl37FGcpaW09QicQr/4bNhAvSPjcCem8reo6eZsXQTzS1OvScJ4RE+HSbAwLiOrMxNZefhU+Qt38SZVolT+D6fDxNgUHwk9txUqr6rI3/FZolT+Dy/CBNgcPdIVuamsuXgSWau3MxZh8QpfJffhAkwpEcnlk9PYdOBkzxkr6DF4dJ7khBt4ldhAgzrFcXS6Sls3Hech1ZV0OqUOIXv8bswAYZfdy7Ov39dyyOrtkicwuf4ZZgAyb07s2RaMp/tOcajqytxSJzCh/htmACp10ezeFoyn+w6yuy3tkqcwmf4dZgAaX2jWTQ1mb/uOMLjb2/F6dL0niTEFfl9mAAjE7pQOHUEH20/wu8kTuEDAiJMgFv6dWXhlOF8WHWYJ9dU4ZI4hYEFTJgA6f1jeH1KEu9VHuKpddskTmFYARUmQMaAbrxqG86aiu/4w3vbJU5hSAEXJsCtN3bj5ewk3iqvZv77O9A0iVMYS0CGCfCrQbH8JWsYq8oO8twHOyVOYShmvQfo6bab4vgPTWPW6kpUReHpOwaiKIres4QI7DAB7hgSj9OlMfvNSkwqPJUpcQr9uRXmgQMHyM7OxmKx4HA4eO211xgyZIintnnNhKHd0TSY/VYlJlXlyfH9JU4fomkOaNkEtEJQCooSrPMeF7RuBq0RLCkoaug1H8OtMHv06MEXX3yBqqps2LCBF154gdWrV7tzSN1MHNYdp0tjzjtbMakw55cSpy/QHAfQTkwBrRnQADNELUQJStRnj/MI2vFs0E6d36OgdXoJNfjmazqOW2Gazf+8+alTp0hM1Oc/hqfcM7wHTk3jiXeqMKkqj916g96TxBVodY+C6+hPPvcwdP0M
RfH+Y5ta/VxwHeJclOfVPYoW8wWKEnTVx3H7GrOyspKZM2dSXV3N2rVr3T2c7u4b0ROnS+P3a7dhUhRmjeun9yRxCZp2BpzfXuQLZ8B5EMy9vb6J1kouiPLcIGj9Cq7hp7jbYQ4dOpSNGzdSUVHBgw8+SFlZ2Q9fs9vt2O12AGpqatw9lddkpfTC6dL4w7vbMZsUHvpFgt6TxEVZgItdbjhB7ejtMecFA00/+ZwGatQ1HcWtMM+ePUtw8LkL7cjISEJDL7zItdls2Gw2ADIzM905lddNtl6HS9N45r0dmFSFB8f01XuS+AlFMaF1uBeaV5+/xgQIhqA0FLWzPqPCcqDxjXMP/ABgAfMNKOZe13QYt8L8+9//zrPPPovJZELTNP70pz+5czjDeSCtN06XxnMf7MSkKOSNvl7vSeInlIgn0dRO0GQHnBAyASXiMf32hP0GTekAjYvP3aUOGY8S8cQ1H8etMDMyMsjIyHDnEIaXM7IPTpfG8+u/QlUVZozqo/ck8SOKoqKEz4TwmXpPAUBRFJSwqRA21a3jBPwTDK5G7i3X49I0Fny4E5MC00ZKnKJ9SZhXKX90XxwujWc/2IlJVZiS1lvvScKPSZjX4H+lJ+ByaTz93g5UVcGWep3ek4SfkjCv0cMZ/XC4NOat245JUZiUcm2PtglxNSTMNnh03A24XBq/X7cNVVW4b0RPvScJPyNhttHsW2/A4dJ4ck0VJkXhnuE99J4k/IiE2UaKovC7X/X/0RPfFSYO6673LOEnJEw3KIrC3NsG4HRpPPZWJSZV4c7EeL1nCT8gYbpJURTm3T4Qp6bx6Jvn3gnh9iFxes8SPk7C9ABFUXjmjhtxujR+u3oLJhXGD5Y4RdsF7JtxeZqiKDz360FMSu7Jw6u28NcdR/SeJHyYhOlBiqKwYMJg7h3Rg4dWVfDJV0evfCMhLkLC9DBVVfjXiTdx17DuzFxZwf/f9b3ek4QPkjDbgaoq/J+7h3BnYjy/WbmZz/Yc03uS8DESZjtRVYV/+5ch3H5THPnLN/HF3lq9JwkfImG2I5Oq8H/vTeRXg2KZsaycL7+WOMXVkTDbmUlV+NN9iYy7sRszlm2iZP9xvScJHyBheoHZpPLn+4eS3r8r05eWU/bNCb0nCYOTML3EYlJ5KWsYoxK6kLOkjM3fSpzi0iRML7KYVF7OTiKtbzRTF5dTcfCk3pOEQUmYXhZkVnnFlkRKn85MXVRGZXWd3pOEAUmYOgg2m3jVlsTw3lFMWVTKtu/q9Z4kDEbC1EmIxcTrk4cztGcnJi8qZfshiVP8k4SpoxCLiYIHRjC4e0cmLyplZ80pvScJg5AwdRZiMVH4QDIDYztiKyxh1xGJU0iYhtAhyMSiaSPo1y0CW0Epe46e1nuS0JmEaRChQWaWTEvm+q5hZBeU8PX3EmcgkzANJCzYzJKcFHp1DiWroJR9xxr0niR0ImEaTHiwmWXTU+jeqQNZC0v4prbxyjcSfsetMHfs2MGoUaMYPXo0GRkZ7N+/31O7AlpEiIXlM1KIiwwha2EJ3x6XOAONW2F27dqV9evX8/nnn/PEE0+wYMECT+0KeB1DLCyfnkqXiCCyFpZw8PhP/0qx8GduhRkTE0NkZCQAFosFs1nedM+TIkMtrJyRSqfQILIKSqg+IXEGCo9cYzY3NzN//nwee+zCv+Rrt9vJzMwkMzOTmpoaT5wq4HQKDcKem0pEiJmsghIO1TVf+UbC57kdpsPhYNKkScyZM4eBAwde8DWbzUZxcTHFxcXEx8s7lLdVVNi5OMOCzGQtLOFwvcTp79wKU9M0cnJyGD9+PBMnTvTQJHEx0eHB2PNSCTarZC0s4Uj9Gb0niXbkVpjr169nzZo1vPnmm6Snp/Poo496aJa4mC7hwazKs2JSFbILSvj+lMTpr9x6tOaOO+6gqUkekPCmrhHBFOVZmbSwhKyCEoryrcREhOg9S3iYPMHAB8V0DGFVnhWnS8NWUEptw1m9JwkP
kzB9VGxkCEX5Vs46XNgKSjkucfoVCdOHxUV2oCjfSmOLA1thKScbW/SeJDxEwvRx3Tt1oCjPyukz5+Ksa5I4/YGE6Qd6dg6lKM9KXVMLkxeVUt/Uqvck4SYJ00/0ig5lVZ6V2tMtPLC4lPpmidOXSZh+pHeXMIryrRyuP8PUxWWcPiNx+ioJ08/0OR/nobpmpi4uo+GsQ+9Jog0kTD/Ut2s4RXmpHDzRRM6SMholTp8jYfqphJgIVuVZ2X+skZyl5TS1SJy+RML0Yzd0i8Cel8reo6eZvrSc5han3pPEVZIw/dyA2I7Yc63sOnKa3OXlnGmVOH2BhBkAbozvyMoZqWz7rp685ZskTh8gYQaIwd0jsedaqayu4zcrNnPWIXEamYQZQG7qEcmKGalUfHuSmSsrJE4DkzADzNCenVg2I4Wyb07wkH0LLQ6X3pPERUiYASipVxTLpiezcV8tjxRV0OqUOI1GwgxQw6/rzJKcFP57by2zVm+ROA1GwgxgKX06s3haMht2fc/sNytxSJyGIWEGOOv10Syemszfdh7l8be34nRpek8SSJgCuDmhC4umJvPx9iP8TuI0BAlTADCqXxcKHhjBh9sO8+SaKlwSp64kTPGD0Td05Y0pw3m/soa5ayVOPUmY4gK/6B/Da5OTWLflEPPe3S5x6kTCFD8zdmA3XslO4u1N1Tzz/nY0TeL0NglTXNQvB8XycnYSRWXVPPv+DonTyyRMcUnjB8fyl6xhrCw9yP/+cKfE6UXyl2bFZWXeFIfTpTFr9RZMisK82weiKIres/yehOkDNE2jZt8RgkKC6Noj2uvnvzMxHpemMfvNSkwmhbnjB0ic7cytMJuamhg7dixfffUVr7/+OpMmTfLULnHed3tqmHf7C5w60YDm0rhuUE8WvP8kHTtHeHXHhKHdcbo0Hn97K2ZVYc4v+0uc7cita8zg4GDWrVsnfxeznWiaxu9v+1dq9h2l4WQjjfVN7C7by7/nvKLLnruTevBv9wzh1U/38eJ/7dVlQ6Bw6yemyWQiNjbWU1vETxzcdYjGusYLPud0uKjcsB2Xy4Wqev+xu3tH9MSlaTy5ZhsmRWHWuH5e3xAI2vUa0263Y7fbAaipqWnPU/mloGALF3sgVFVVXe9G3p/cC6cLnlq3DbNJ4aFfJOi2xV+1a5g2mw2bzQZAZmZme57KL8Vd343YPjHsr/oW1/mXZAWFWBh9r1X367vs1F44XS6efm8HqqIwM72vrnv8jfwe0+Be+GgeQzMGExIeQoeIENLvH8kjL+fqPQuAKWm9efbOG/njx7tY+Pk+vef4Fbd/Yk6cOJGqqirCwsL48ssveemllzyxS5wXFRPJH//zaRytDhRVwWQy6T3pAtNG9sGpwYIPd6IqCrm3XK/3JL/gdpjvvvuuB2aIKzFbjPsr5xmj+uByaTy//ivMqsK0kX30nuTzjPvdFj4lb/T1OFwaz36wE5OqMCWtt96TfJqEKTxmZnpfXJp27gEhVcGWep3ek3yWhCk86qFfJOBwasxbtx2TojAppZfek3yShCk8bta4fjg1jd+v24aqKtw3oqfek3yOhCnaxexx/XC6XDy5pgqTonDP8B56T/IpEqZoF4py7onuThfMeWcrJlVh4rDues/yGRKmaDeKovDk+P44XS4ee6sSVVX4dWK83rN8goQp2pWiKDyVORCni3Ov51QUbh8Sp/csw5MwRbtTFIWn7xiIS9P47eotqArcdpPEeTkSpvAKRVGYf+eNOF0ajxRt4RVV4VeD5CWDlyJPYhdeoygKz/16EPcl9+ThVRX8186jek8yLAlTeJWqKjw/YTD3JPVgpn0zG3ZJnBcjYQqvU1WFF+66iQlDu/Pgigo+3f293pMMR8IUulBVhT/eM4Q7hsSRv2Iz/733mN6TDEXCFLoxqQr/fm8itw2OJXfZJr78ulbvSYYhYQpdmVSF/3dvIrfe2I3py8op2X9c70mGIGEK3ZlNKn++fygZA2LI
WVJO2Tcn9J6kOwlTGILZpPIfk4Yx+oYuTFtSxqYDgR2nhCkMw2JS+UtWEiMTujBtSTkVB0/qPUk3EqYwlCCzyivZSaT26czURWVUVtfpPUkXEqYwnCCzyquTkxjeO4opi0qp+q5O70leJ2EKQwo2m3h98nCG9uzE5MJSth+q13uSV0mYwrBCLCYKHhjBkB6dsBWWsqMmcOKUMIWh/SPOQfEdsRWW8tXhU3pP8goJUxhehyAThVNH0L9bBLbCUnYfOa33pHYnYQqfEBpkZvG0ZBK6hpNdUMLeo/4dp4QpfEZYsJnFOcn07hJGVkEpX3/foPekdiNhCp8SHmxmaU4yPTt3ILughP3H/DNOt8MsLCwkLS2NUaNGUVVV5YlNQlxWRIiFZdNTiIsMIaughAO1jVe+kY9xK8wTJ07w2muv8fnnn7N06VJmz57tqV1CXFbHEAvLZ6QSE3EuzoPHm/Se5FFuhVlaWkp6ejoWi4WEhARqa2txuVye2ibEZUV2sLBiRgqdw4LIKiih+oT/xOlWmCdPniQqKuqHj8PDw6mvD5xfAgv9dQoNYuWMVDp2sJBVUMKhuma9J3mEW2FGRUVRV1f3w8cNDQ1ERkb+8LHdbiczM5PMzExqamrcOZUQlxQVFsTKGSmEBZnJWlhCjR/E6VaYqampfPbZZzgcDr755huio6NR1X8e0mazUVxcTHFxMfHx8tb4ov1Ehwdjz0sl2KySVVDCkfozek9yi1thdu7cmby8PG655RamTJnCiy++6KldQlyzLuHBrMqzYlYVsgpKOHrKd+N0+9cl+fn5bNy4kS+++ILExERPbBKizbpGBFOUZ0UBsgpK+P60b8YpTzAQfiemYwir8qy4XBrZBaXUNpzVe9I1kzCFX4qNDKEo30qr00V2QQnHfSxOCVP4rbjIDhTlWWludWIrLOVEY4vek66ahCn8Wnync3GePuNgcmEpdU2+EaeEKfxej6hQVudbqWtqwVZYSn1Tq96TrkjCFAGhZ+dQivKtnGhsYfKiUuqbjR2nhCkCxnXRYRTlWfn+9BkeWFzGqTPGjVPCFAGld5cwVuVZqalrZuriMk4bNE4JUwScvl3DKcqzUn2imZwl5TScdeg96WckTBGQEmLCKcpL5ZvaRqYvKaepxVhxSpgiYPXrFsGqPCtfH2tg+tJymlucek/6gYQpAlr/2AhWzkhl15HTzFhmnDglTBHwbozviD03lR01p8hfsYkzrfrHKWEKAQyKj8Sem8rW6jryV2zWPU4JU4jzBnePZMWMVLYcPMnMlZs569AvTglTiB9J7NmJ5dNTKD9wkofsFbQ49HlzOQlTiJ8Y1iuKZdOT2bjvOA+vqqDV6f04JUwhLmL4dZ1ZOj2FL76u5bdFW7wep4QpxCUk9+7MkmnJfLr7GI+ursThxTglTCEuI/X6aBZPS+aTXUeZ/dZWr8UpYQpxBWl9o1k0NZm/7jjCnLe34nRp7X5OCVOIqzAyoQsFD4ygePsRfvdO+8cpYQpxlUbf0JWFU4bz4dbDzF1Thasd45QwhbgG6f1jeH1KEu9WHuKpddvaLU4JU4hrlDGgG6/ahvPO5u/4w3vb0TTPxylhCtEGt97YjZezk3irvJr57+/weJwSphBtNH5wLH/JGoa99CDPfbDTo3GaPXYkIQLQbTfF8WeXxqzVWzCpCn+4fSCKorh9XAlTCDfdmRiPS9N4/K2t3DEkjmG9oq58oyto813ZN954g379+jFgwAC3Rwjh6yYM7c6Gx9M9EiW4EeZdd93Fzp07PTJCCH/QKzrUY8dq813ZmJgYj40QQlyoXa8x7XY7drsdgJqamvY8lRB+5bJhNjQ0MG7cuJ99Pjc3l9zc3Cse3GazYbPZAMjMzGzjRCECz2XDDA8Pp6SkxFtbhBDntfnBn7Vr1zJu3Diqq6sZN24cmzZt8uQuIQJam68x7777bu6++25PbhFCnOe1Jxjs27fviteZNTU1xMfHe2nRlcmey5M9l3c1
e/bt23fxL2gGctttt+k94QKy5/Jkz+W5s0eexC6EARkqzH/8asUoZM/lyZ7Lc2ePomnt8CpPIYRbDPUTUwhxjoQphAEZLkyjvJyssLCQtLQ0Ro0aRVVVla5bmpqaSEtLo1OnTqxevVrXLQA7duxg1KhRjB49moyMDPbv36/rngMHDnDzzTczZswYRo4cqfv36x/27NmDxWJp27PnPPbYsIccPXpUa2lp0fr376/bhuPHj2tJSUlaS0uLtnfvXi0jI0O3LZqmaQ6HQzt8+LA2f/58raioSNctmnbue1RXV6dpmqZ99NFH2rRp03Td09raqjmdTk3TNO2TTz7R7r//fl33/MPkyZO1sWPHahs3brzm2xruHQyM8HKy0tJS0tPTsVgsJCQkUFtbi8vlQlX1uYNhMpmIjY3V5dwX8+PvkcViwWzW939GPz7/qVOnSExM1HHNOWVlZcTGxmIymdp0e8PdlTWCkydPEhX1z1eih4eHU19fr+MiY2pubmb+/Pk89thjek+hsrKStLQ0Hn74YcaOHav3HJ5//nnmzp3b5tvr8n917r6crL1FRUVRV1f3w8cNDQ1ERkbqN8iAHA4HkyZNYs6cOQwcOFDvOQwdOpSNGzdSUVHBgw8+SFlZmW5b1q9fz4gRI4iOjm7zMXQJ0+gvJ0tNTeWZZ57B4XBQXV1NdHS0bndjjUjTNHJychg/fjwTJ07Uew5nz54lODgYgMjISEJDPfcWH21RWVnJp59+ypdffsm2bdvYvXs3a9euJS4u7uoP4vlLXvesWbNGGzt2rBYaGqqNHTtWKy8v12XHG2+8oVmtVm3kyJFaZWWlLht+bMKECVqfPn20wYMHa4888oiuWz744AOtQ4cO2pgxY7QxY8Zos2bN0nXPJ598ot1yyy1aenq6NmbMGG3z5s267vmxqVOntunBH3nmjxAGJPfPhDAgCVMIA5IwhTAgCVMIA5IwhTAgCVMIA5IwhTCg/wFyijiznjuILAAAAABJRU5ErkJggg==\n"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Hyperparameters and initialization\n",
    "lr = 1\n",
    "w = np.zeros(x_train.shape[1])\n",
    "b = 0\n",
    "epoch = 1  # counts parameter updates, not passes over the data\n",
    "\n",
    "flag = True\n",
    "while flag:\n",
    "    flag = False  # set back to True if this pass finds a misclassified point\n",
    "    for i,v in enumerate(x_train):\n",
    "        y_hat = (v*w).sum() + b\n",
    "        # Misclassified: y_i * (w . x_i + b) <= 0, as in the update rule above\n",
    "        if y_train[i] * y_hat <= 0:\n",
    "            w = w + lr * v * y_train[i]\n",
    "            b = b + lr * y_train[i]\n",
    "            print(\"epoch {} % w is {},b is {}\".format(epoch,w,b))\n",
    "            epoch += 1\n",
    "            flag = True\n",
    "\n",
    "print(\"\\nfinally get w is {}, b is {}\".format(w,b))\n",
    "\n",
    "draw_fig(x_train,y_train,w,b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
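    "class Perception(nn.Module):\n",
    "    # Perceptron f(x) = sign(x w + b) for a batch x of shape (N, n)\n",
    "    def __init__(self,n):\n",
    "        super(Perception,self).__init__()\n",
    "        self.w = nn.Parameter(torch.ones((n,1)))\n",
    "        self.b = nn.Parameter(torch.zeros((1,1)))\n",
    "\n",
    "    def forward(self,x):\n",
    "        # torch.sign returns 0 for points exactly on the hyperplane\n",
    "        pred = torch.mm(x,self.w) + self.b\n",
    "        pred = torch.sign(pred)\n",
    "        return pred\n",
    "\n",
    "    def loss_func(self,x,y):\n",
    "        # Perceptron loss -y_i * (w . x_i + b) for one sample x of shape (1, n)\n",
    "        return -y * (torch.mm(x,self.w)+ self.b).squeeze()\n",
    "\n",
    "    def init_param(self):\n",
    "        # Random (Kaiming uniform) initialization of w and b\n",
    "        for name,param in self.named_parameters():\n",
    "            nn.init.kaiming_uniform_(param)\n",
    "\n",
    "def SGD_optimizer(net,input_x,label):\n",
    "    # Despite its name, this helper does not update any parameters. It picks one\n",
    "    # misclassified sample at random and returns (its loss, 0), or (-1, -1)\n",
    "    # when every sample is classified correctly.\n",
    "    y_pred = net(input_x).squeeze(1)\n",
    "    # A sample is misclassified when sign(w . x + b) != y. This also covers\n",
    "    # points on the hyperplane (sign = 0), i.e. y_i * (w . x_i + b) <= 0\n",
    "    wrong = torch.nonzero(y_pred != label).squeeze(1).tolist()\n",
    "    if len(wrong) == 0:\n",
    "        return -1,-1\n",
    "\n",
    "    index = random.choice(wrong)\n",
    "    loss = net.loss_func(input_x[index].unsqueeze(0),label[index])\n",
    "    return loss, 0"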
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done!\n",
      "\n",
      "finally get w is [-0.8050575  3.1798725], b is [-5.560445]\n"
     ]
    },
    {
     "data": {
      "text/plain": "<Figure size 256x256 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAO0AAADhCAYAAAAkjzL0AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAAnYAAAJ2AHHoLmtAAAdXklEQVR4nO3de1hTd54/8DeEIIZoDBRF26JyEcE7olxjjthWjRadGWdAcbx0ZKhdW7vibqt20f7a6v60dmfcocqUqWM7UHVGQKuRmZYp1HpBK5OCoogGFUXRgES55/LdP6ypEQU1wMkhn9fz9Hnge3JOPqm8yYWcd5wYYwyEEMFw5nsAQsiTodASIjAUWkIEhkJLiMBQaAkRGBe+BwgMDISfnx/fYxBidy5cuICysrI267yH1s/PD2q1mu8xCLE7KpXqoev08JgQgaHQEiIwFFpCBIZCS4jAtBvaixcvIjIyEkqlElFRUSguLrbartVqwXEcoqKisGnTJst6eno6IiIiEB0d3WYfQroLM5TDXLsA5upQmGvmgLX+i995jJdhrk2EuXoCzLpYmJuPPOWB2mEwGJjJZGKMMZaXl8fi4uKsts+ZM4cdPnyYmc1mxnEc02q1rKamhoWEhLDW1lZWXl7OYmJi2rsKNn369Ha3E/I0zEYdM12fwEzXAn7673oIMxvK+ZnHdIeZqsMfmGccM7doHrnPo7LR7j2ti4sLnJ3vXuT27dsYM2aM1fbS0lJERkbCyckJKpUKhw4dQmFhITiOg1gshr+/P3Q6Hcxm89P9RiHkKbGmPQBreGCxHqxhOz/zNOcC5saHzLP1iY/V4d9pNRoNli5disrKSmRlZVltuz+M/fr1Q01NDVxcXCCXyy3rUqkUer3eai0jIwMZGRkAgKqqqicempAOmW8AMDywyABTNR/TAOYaAM1t1003n/hQHb4QNXbsWBw9ehT79u3DsmXLrHd2/ml3vV4PDw8PyOVy1NXVWdbr6+shk8ms9ktISIBarYZarcagQYOeeGhCOuLkNhVwkj6wKgHcXuZnnl4c4OT+wKob0HvmEx+r3dC2tLRYvpbJZJBIJFbbg4ODcfz4cQBAbm4uFAoFwsLCUFBQAKPRiIqKCnh6elqFm5BuIQ4Fev/ibnCdpIBTH6CXAk5PEZLO4CQOBCS/uW+evoBrCJwkCU98rHYfHh8+fBjr1q2DSCQCYwwfffQRcnNzUVtbi3nz5mHDhg1YsmQJDAYDYmNj4evrCwBITEyEQqGASCRCamrq091KQmzg5OQEp75rwCQLAeMZwMUXTi78vsfduc+/gUnmAIZiQPQ8nMTDn+o4TozxWzejUqnovceEPMSjskGPWwkRGAotIQJDoSVEYCi0hAgMhZYQgaHQEiIwFFpCBIZCS4jAUGgJERgKLSECQ6ElRGAotIQIDIWWEIGh0BIiMO2G9vTp04iOjsakSZMQExMDrVZrtX3WrFngOA4cx8Hd3d3SvCiRSCzr2dnZXTc9IQ6o3ZPgvby8cODAAchkMuTm5uK9997D9u0/FWPt3bsXAKDT6cBxHEaPHg0A8PHxQX5+ftdNTYgDa/eetn///pZ+J7FYDBeXh2d8165diIuLs3xfVVUFpVKJ+Ph43LhxoxPHJYQ81nPapqYmrF27FitWrHjo9szMTMybN8/yfUVFBQoKChAbG4vk5OTOmZQQAevMgpgOQ2s0GhEfH4+VK1ciKCiozfaKigowxqw+Y9bT0xMAEBcXB41G02afjIwMqFQqqFQqqlAlPZq+yYDV2SV498vSTjtmu6FljGHx4sWYNm0aZs+e/dDLZGRkWN3LNjQ0wGQyAQAKCgoe+oHRVKFKejrGGPYXV+GFjwqQe+o6xj7fr9OO3e4LUQcOHMCePXtQWVmJXbt2YezYsZg2bZqljREAdu/ejby8PMs+Z8+eRWJiIqRSKcRiMdLS0jptWEKEoLK2ESl7T+Gbspv4VehzWDU9CHJ3
1047fruhnTlzJhobG9u7SJsP2Bo/fjyKiopsn4wQgTGazNh++CI++uocBvZzw87fhiPc17PTr6fDjwUhhHSs+EodVmWVoLy6Hks5P7w22Q+9XERdcl0UWkJsUN9ixOZ/lGHHkYsIHeIB9XIF/Ps/+HEknYtCS8hT+qq0Gil7T6Gx1YQNPx+FX45/Hs7OTl1+vRRaQp7QdX0z1u07jdzT1zFr7CC8MyMYXn16ddv1U2gJeUwmM0NG4SVszC2D3F2MHa9MhHKYV7fPQaEl5DGcuXYbq7JKUHJVj0SFL5ZPCUBv1655oakjFFpC2tHUasLv88qRfkiLkc/KsP/1aAQN7MvrTBRaQh7h23M3sSanBLcaDEh5ORgJYYMh6oYXmjpCoSXkAbr6Fry3vxR7NVWYNsIb62JHwFvmxvdYFhRaQn7EGMPu7yuxXn0WElcRPlkQiheDB/A9VhsUWkIAnL9Rj9XZJfj+Yi0WRg5B8kuBkPayz3jY51SEdJMWowkff3MBW/MvwL+/FNmvRWFMJ56R0xUotMRhHdPWYHV2Ca7VNeM/pgZicdQQuIjsv+uQQkscTl1jK9arz2D391cwOdALOxZPxPMeEr7HemwUWuIwGGPYq6nCe/tL4ezshD/MG4cZowbCyYn/P+M8CZsqVBctWoSQkBBwHIekpCTLenp6OiIiIhAdHd3mfFtC+HCppgELPj2ON3dpMHWkN75eocTM0YMEF1jAxgpVAPj4448RHh5u+b62thZbt27FsWPHcOnSJSQlJVk1WxDSnQwmMz45pMXvvy6Hj4cEf3s1AqFDPPgeyybthrZ///6Wrx9Vobp8+XK4ublh1apVmDZtGgoLC8FxHMRiMfz9/aHT6WA2m+HsbP9P8EnPUnT5FlZnlUCra8AbMf747SQ/uLoI/+fwsZ7T3qtQ/eSTT6zWN2/eDE9PT1RXV2Py5MkIDw/HrVu3IJfLLZeRSqXQ6/VWaxkZGcjIyAAAamMkne52swEbc88io/AyInw9sXX+eAx9xp3vsTpNh6Ftr0L1XlXqgAEDEBoaivLycsjlctTV1VkuU19fbyk8vychIQEJCQkAAJVKZettIATA3ReaDp66jnX7TsNgMuPDOWPw85BnBfm8tT3thrajClW9Xg+ZTIampiZoNBoMHjwYfn5+SElJgdFoRGVlJTw9PemhMelyV+uakJJzCnlnb+AXIc9hzYwgeHRiA6I9salCde7cudDr9TAYDEhOTrY8B05MTIRCoYBIJEJqamq33BDimExmhj8fuYjN/yjDgL5uyEwMQ6TfM3yP1aWcWGd+XsFTUKlUUKvVfI5ABOrUVT1WZZXg7PXbWKr0w2uT/eEm5ufE9K7wqGzQmyuI4DS0GPE/X53Dp4crMH6wHOo3FAgY0IfvsboNhZYISt6ZaqTsPY07zQa8P3sU4id0TwOiPaHQEkG4cbsZ735ZigMl1zBz9ECkvByM/n3s58T07kShJXbNbGbIPH4Z/z/3LPq6ibF98QRMDuzf8Y49GIWW2K2y63ewOrsEmso6/CZ6KN58IQASV/qRpf8DxO40G0z433+WI61AixGD+mLfsiiMGCTreEcHQaElduW7ch3W5JRAd6cFa2YEYUHEELtoQLQnFFpiF2rqW/DBgTPI+tdVvBQ8ADt/G46Bst58j2WXKLSEV4wx/PXkFaxXn4Gbiwjb5o/HtJHefI9l1yi0hDfam3cbEAsrarEgfDBWTg1EHzcx32PZPQot6XYtRhO25WuR+s15+Hq5I2tpJMb5yDvekQCg0JJudryiFquzS3DlViOSXxqGV6KHQiyABkR7QqEl3ULfaMCGg2ew80QllMO8sH3RBEE1INoTCi3pUowx7PvhbgMiAGyZOw4vjxZeA6I9samNcf78+YiMjERYWBh27NhhWZdIJOA4DhzHITs7u2smJ3avsrYRC7efwPKdGrwYPAB5KzjEjhFmA6JdYe2orq5mdXV1jDHGDh48yBYtWmS1/dy5c4wx
xpqbm1lAQABraWlhjDEWGBjY3mGtTJ8+/bEvS4Sh1WhiW/PPs8B31GzK5nxWqK3heyRBelQ2bGpjDAgIAAC4urrC2dnZ8hu0qqoKSqUSAwcOxJYtW6yOQ3o2TWUdVmWV4MKNeiyL8UeS0he9XHrOien2wKY2xns2btyI+Ph4iMV3/8ZWUVEBT09PZGZmIjk5GZ9//rnV5amNsee502zAh38vw2fHLiFsqAcOvqmAn5eU77F6po7uog0GA4uNjWXZ2dkP3Z6RkcF+9atfMZPJ1Gab0WhkI0eOfKqHAEQ4DpZcY2EffM3GvPt3tvvEZWY2m/keqUd4qofHrIM2RrVaje3bt2P//v2WxsWGhga4ublBJBKhoKAAfn5+XfLLhvDvmr4JKXtP46vSavxs3LN4Z0YQPKW9+B6rx7OpjXHBggXw8fHB1KlTAQA7d+7E1atXkZiYCKlUCrFYjLS0tG65IaT7mMwMnx29iA//XoZn+vTC57+ZCEWAF99jOYx2Qztz5kw0NjY+crtOp2uz5u3tjaKiItsnI3bpdJUeq7NKcLrqNpKUvng9JqBHNSAKAb25gjyWxlYjfvd1Of70XQXGPCfDgTcUCPR2nAZEe0KhJR36puwG/ivnFPSNBrwbOwLzJvo4XAOiPaHQkke6cacZ/+/LUuwvvoYZo+42IA7o65gNiPaEQkvaMJsZdp6oxH8fPIM+bmL8aWEopgQN4Hss8iMKLbFSXn23AfHkpVt4JWoo/v3FYXDvRT8m9oT+NQiAuw2Iqd+cx7aCCxju3Rf7lkVj5LPUgGiPKLQER87rsCbnFKpvN+Pt6UFYGDEYLnRiut2i0Dqw2oZWfHDgDPYUXcELQf3xlyVheLYfNSDaOwqtA2KMIavoKj5Qn4FY5IRt80MwdYQ3necqEBRaB1Oha8A7OSU4cqEGv/6xAbEvNSAKCoXWQbQazfjjtxew5Z/n4fuMO/YsjUQINSAKEoXWAXx/sRarskpwubYR//7CMCxRUAOikFFoezB9kwH/ffAsvjh+GYqAZ5C+MBSDPd35HovYiELbAzHGsL/4Gt79shSMMfwubixmjaVCtZ7CpjZGrVYLjuMQFRWFTZs2WdbT09MRERGB6OhoFBcXd83k5KEqaxvxyp9P4PUv/oWY4V7IS1Zi9rhnKbA9SXt1Fx21Mc6ZM4cdPnyYmc1mxnEc02q1rKamhoWEhLDW1lZWXl7OYmJinqpSgzwZg9HE0grOs+HvHGSTP/yGHb2g43skYqMuaWMsLS1FZGQkAEClUuHQoUPw8vICx3EQi8Xw9/eHTqeD2Wy21NGQzvfDjw2I52/UYynnh9cm+1EDYg9mUxuj2Wy2fN2vXz/U1NTAxcUFcvlPf0qQSqXQ6/VWa6Rz1LcYsfkfZdhx5CJCh3hAvVwB//7UgNjTdRhao9GI+Ph4rFy5EkFBQVbb7r/31Ov18PLyglwuR11dnWW9vr4eMpn1G8+pQtV2X5VWI2XvKTS2mrDh56Pwy/HP04npDqLdx6ysgzbG4OBgHD9+HACQm5sLhUKBsLAwFBQUwGg0WvqPH3xonJCQALVaDbVajUGDBnXerXEA1/XNePXzk0j87HtMHOqBvGQl4iZQk4QjsamNccOGDViyZAkMBgNiY2Ph6+sLAEhMTIRCoYBIJEJqamq33JCezmRmyCi8hI25ZZC7i7HjlYlQDqMGREfkxBhjfA6gUqmgVqv5HMHunbl2G6uySnDqqh6Jk3zxRkwAervSC0093aOyQW+usGNNrSb8Pq8c6Ye0GPWcDPvfiMZw7758j0V4RqG1U9+eu4k1OSWoazBg7cvBmBc2GCJ63kpAobU7uvoWvLe/FHs1VZg+0hvrYkdQAyKxQqG1E4wx7P6+EuvVZyFxFeGTBaF4MZgaEElbFFo7cP5GPVZnl+D7i7VYGDkEyS8FQkoNiOQR6CeDR80GEz7Ov4Ct+ecxbEAf5PxbFEY/14/vsYido9Dy5Ji2BquzS3CtrhlvTRuO
RZFDqAGRPBYKbTe71dCK9eoz+OvJK5gc6IUdiyfieQ8J32MRAaHQdhPGGHI0V/H+/jNwdnbCH+aNw4xRA+k8V/LEKLTd4FJNA97JOYVD5TrMC/PBW9OGQ9abGhDJ06HQdiGDyYxPDmnx+6/L4eMhwd9ejUDoEA++xyICR6HtIkWXb2F1Vgm0ugYsnxKARIUvXF3ohSZiOwptJ7vdbMDG3LPIKLyMSD9PbJs/HkOeoQZE0nkotJ2EMYaDp65j3b7TMJoZNv9yDH5GhWqkC1BoO8HVuiak5JxC3tkbmDP+OaxWBcHD3ZXvsUgP1e6TrMbGRkRERKBfv37YuXNnm+2zZs0Cx3HgOA7u7u6WulSJRGJZz87O7prJ7YDRZEb6IS1e/KgAWl0DMhPD8OEvx1BgSZdq9562V69eyM7OxrZt2x66fe/evQAAnU4HjuMwevRoAICPjw/y8/M7d1I7c+qqHm9nFaPs+h0sVfrhtcn+cBPTiemk67UbWpFIBG9v7w4PsmvXLsTFxVm+r6qqglKpxMCBA7FlyxarKlaha2gx4n++OodPD1dg/GA51G8oEDCgD99jEQfSKc9pMzMz8dlnn1m+v1folpmZieTkZHz++edWlxdqG+M/z1bjv3JO406zAR/8bBTiQqkBkXQ/m/9wWFFRAcYY/Pz8LGuenp4AgLi4OGg0mjb7CK2N8cbtZryWcRKv/Pl7hAyW4+tkJeZOpAZEwg+b72kzMjIwb948y/cNDQ1wc3ODSCRCQUGBVZiFxmxmyDh+GRsPnoVMIsafF08AF9hzHuoTYeowtLNnz0ZxcTHc3d1x5MgRqFQqS4UqAOzevRt5eXmWy589exaJiYmQSqUQi8VIS0vruum7UNn1O1iVVYwfruixJHoolr8QAIkr/YWM8K/Dn8KcnJx2tz/4qXjjx49HUVGRTUPxqdlgwpa8cvzxWy1GDOqLfcuiMGKQrOMdCekmdNdxn+/KdViTUwLdnRa8MyMIv44YQg2IxO5QaAHU1Lfg/QNnkP2vq3gpeAB2/jYcA2W9+R6LkIdy6NAyxvDXk1ewXn0Gbi4ipP16PKaO6Pjv0oTwyWFDq715twGxsKIWC8IHY+XUQPRxoxPTif1zuNC2GE3Ylq9F6jfn4ddfiuzXojD2+X58j0XIY3Oo0B6vqMXq7BJcudWI5JeG4ZXooRBTAyIRGIcIrb7RgA0Hz2DniUooh3lh+6IJ1IBIBKtHh5Yxhn0/VOG9/aUAnLBl7ji8PJoaEImw9djQVtY2Yk3OKXx77ibmTvTB29OGQyahF5qI8PW40BpMZvzpuwr87utzeE4uwe6kCEwcSg2IpOfoUaHVVNbh7T3F0N5swLIYfyQpfdHLhU5MJz1LjwjtnWYDPvx7GT47dglhQz1w8E0F/LykfI9FSJcQfGhzf2xAbDaasPEXozFn/HP0QhPp0QQb2qq6Jqzbdxr/KK3Gz8c9izUzguAp7cX3WIR0OZvaGBctWoSQkBBwHIekpCTLenp6OiIiIhAdHd3m1D1bmcwMn35XgRc/KkBZ9R385Tdh+ChuLAWWOAyb2hgB4OOPP0Z4eLjl+9raWmzduhXHjh3DpUuXkJSUZHWSvC1OV+mxOqsEp6tuI0npi9djAqgBkTicdu9pH6eNcfny5VAqlcjNzQUAFBYWguM4iMVi+Pv7Q6fTwWw22zRkY6sR69VnEPuHwxA5O2H/G9H4j6nDKbDEIdn0nHbz5s3w9PREdXU1Jk+ejPDwcNy6dQtyudxyGalUCr1eb7X2JI5pa5C8+wfcbjbg3dgRmEeFasTB2RTae62LAwYMQGhoKMrLyyGXy1FXV2e5TH19PWQy67qWJ6lQFTk7Yezz/bD25WD07+tmy7iE9Ag2hVav10Mmk6GpqQkajQaDBw+Gn58fUlJSYDQaUVlZCU9PTzg7Wz8KT0hIQEJCAgBApVK1ex0ThnhgAn2mKyEWNrUxzp07F3q9HgaDAcnJyZZPEkhMTIRCoYBIJEJq
amqX3whCHIkTY4zxOYBKpYJareZzBELs0qOyQWeAEyIwFFpCBIZCS4jAUGgJERgKLSECQ6ElRGAotIQIDIWWEIGh0BIiMBRaQgSGQkuIwFBoCREYCi0hAkOhJURgKLSECIxNFarz589HZGQkwsLCsGPHDsu6RCIBx3HgOA7Z2dmdPzUhDsymCtW1a9ciICAALS0tGDVqFObOnQtXV1f4+PggPz+/K+YlxOHZVKEaEBAAAHB1dYWzs7Pl4ziqqqqgVCoRHx+PGzdudOK4hJBOeU67ceNGxMfHQyy++/mvFRUVKCgoQGxsLJKTk9tcPiMjAyqVCiqVqsM2RkKINZtDm5mZiaKiIqSkpFjW7lWrxsXFQaPRtNknISEBarUaarUagwYNsnUEQhyKTRWqarUa27dvx/79+y01qQ0NDXBzc4NIJEJBQQH8/Pw6ZVBCyF02VaguWLAAPj4+mDp1KgBg586duHr1KhITEyGVSiEWi5GWltblN4IQR9JhaHNych65TafTtVnz9vZGUVGRTUMRQh6N3lxBiMBQaAkRGAotIQJDoSVEYCi0hAgMhZYQgaHQEiIwFFpCBIZCS4jAUGgJERgKLSECQ6ElRGAotIQIDIWWEIGxqY1Rq9WC4zhERUVh06ZNlvX09HREREQgOjoaxcXFnT81AQDortYgdfmnWMGl4C/v/Q0Ntxv5Hol0A5vaGN966y2sX78eERERiImJwZw5cyCTybB161YcO3YMly5dQlJSEvLy8rpkeEd280oNlob8J+7cqofZZMbZwvP4+i/f4o8/fAhXN1e+xyNdyKY2xtLSUkRGRsLJyQkqlQqHDh1CYWEhOI6DWCyGv78/dDodzGZzpw/u6HZtzLEEFgAMLQbUXruFQ3sKeZ6MdDWbntPeH8Z+/fqhpqYGt27dglwut6xLpVLo9Xqr/aiN0XYXNBctgb2nqb4Z2uKL/AxEuo1Nob1X5gYAer0eHh4ekMvlqKurs6zX19dDJpNZ7UdtjLYbN2UUXFytn91I+vbG6EnBPE1EuotNoQ0ODsbx48cBALm5uVAoFAgLC0NBQQGMRiMqKirg6elpFW7SOX7x5gz093kGblI3AHcDOyzUDxOmj+N5MtLVbGpj3LBhA5YsWQKDwYDY2Fj4+voCABITE6FQKCASiZCamtrlN8IRucvc8UnxZuTvPoKKkssYPSkYE1Xj6BekA3BijDE+B1CpVFCr1XyOQIhdelQ26NcyIQJDoSVEYCi0hAgMhZYQgbHpA7g6w4ULF6BSqdq9TFVVlV39PZfmaZ+9zQPY30yPM8+FCxcevoEJwPTp0/kewQrN0z57m4cx+5vJlnno4TEhAiOI0CYkJPA9ghWap332Ng9gfzPZMg/vb64ghDwZQdzTEkJ+QqElRGAEFdq0tDQEBARg+PDhvM1gb1U6HVUCdafTp08jOjoakyZNQkxMDLRaLa/zXLx4EZGRkVAqlYiKirKLfy8AOHfuHMRiMY4dO/Z0B+i017C7QXV1NWttbWWBgYG8XH9NTQ0LCQlhra2trLy8nMXExPAyx/2MRiO7du0aW7t2Lfviiy94naW6uprV1dUxxhg7ePAgW7RoEa/zGAwGZjKZGGOM5eXlsbi4OF7nuWf+/PlsypQp7OjRo0+1P+9vrngS/fv35/X6H1Wlw+fpcB1VAnWn+/99xGIxXFz4/fG6//pv376NMWPG8DjNXcePH4e3tzdEItFTH0NQD4/59jhVOgRoamrC2rVrsWLFCr5HgUajQUREBJYtW4YpU6bwPQ7ef/99vP322zYdw+7uaevr6/HCCy+0WV+yZAmWLFnCw0Q/eZwqHUdnNBoRHx+PlStXIigoiO9xMHbsWBw9ehRFRUV49dVXLU0rfDhw4ABCQ0Ph6elp03HsLrRSqfTpn6B3sbCwMKSkpMBoNKKyspKqdB7AGMPixYsxbdo0zJ49m+9x0NLSgl69egEAZDIZJBIJr/NoNBrk5+fjyJEjKCkpQVlZGbKy
sjBw4MAnO1DnPsXuWnv27GFTpkxhEomETZkyhZ04caLbZ0hLS2Ph4eEsKiqKaTSabr/+h5k1axYbOnQoGzlyJHv99dd5m+PLL79kvXv3ZkqlkimVSrZ8+XLeZmHs7otPCoWCcRzHlEolO3nyJK/z3G/hwoVP/UIUvSOKEIGhx3aECAyFlhCBodASIjAUWkIEhkJLiMBQaAkRGAotIQLzfwTDzOWjVNdoAAAAAElFTkSuQmCC\n"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "x = torch.from_numpy(x_train)\n",
    "x = x.float()\n",
    "y = torch.from_numpy(y_train)\n",
    "y = y.float()\n",
    "\n",
    "model = Perception(len(x[0]))\n",
    "model.init_param()\n",
    "lr = 1\n",
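    "while True:\n",
    "    loss, status = SGD_optimizer(model, x, y)\n",
    "    if status == -1:  # no misclassified points left\n",
    "        print('Done!')\n",
    "        break\n",
    "\n",
    "    loss.backward()\n",
    "    with torch.no_grad():\n",
    "        for param in model.parameters():\n",
    "            # Gradient step: w += lr * y_i * x_i, b += lr * y_i\n",
    "            param -= param.grad * lr\n",
    "            # Reset the gradient, otherwise it accumulates across iterations\n",
    "            param.grad.zero_()\n",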
    "\n",
    "w = model.w.data.numpy()[:,0]\n",
    "b = model.b.data.numpy()[0]\n",
    "print(\"\\nfinally get w is {}, b is {}\".format(w,b))\n",
    "\n",
    "draw_fig(x_train,y_train,w,b)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
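    "2. Dual form\n",
    "   \n",
    "   Here $w$ and $b$ are expressed as linear combinations of the training instances $(x_i, y_i)$, and the coefficients are learned instead.\n",
    "\n",
    "   As in the primal form, initialize $w$ and $b$ to 0. If the point $(x_i, y_i)$ has been used for an update $n_i$ times, let $\\alpha_i = n_i \\eta$. Then $w$ and $b$ can be written as\n",
    "   $$\n",
    "      w = \\sum_{i=1}^N \\alpha_i y_i x_i\n",
    "      \\\\\n",
    "      b = \\sum_{i=1}^N \\alpha_i y_i\n",
    "   $$\n",
    "\n",
    "   The output of the dual form is $\\alpha = (\\alpha_1, \\dots, \\alpha_N)^T$ and $b$.\n",
    "   \n",
    "   Pick a point $(x_i, y_i)$ from the training set. If $y_i\\left(\\sum_{j=1}^N \\alpha_j y_j x_j \\cdot x_i + b\\right) \\le 0$, update $\\alpha_i \\leftarrow \\alpha_i + \\eta$ and $b \\leftarrow b + \\eta y_i$.\n",
    "   \n",
    "   Repeat until there are no misclassified points."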
   ]
  },
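  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This sketch shows how the dual coefficients map back to the primal parameters. The helper `dual_to_primal` is hypothetical; it evaluates $w = \\sum_i \\alpha_i y_i x_i$ and $b = \\sum_i \\alpha_i y_i$. It uses $\\eta = 1$ and the final $\\alpha = (2, 0, 5)^T$ reported by the dual-form run below. The result is the same $w = (1, 1)$, $b = -3$ as in the primal form.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def dual_to_primal(alpha, X, y):\n",
    "    # w = sum_i alpha_i y_i x_i,  b = sum_i alpha_i y_i\n",
    "    coef = alpha * y\n",
    "    return coef @ X, coef.sum()\n",
    "\n",
    "X = np.array([[3., 3.], [4., 3.], [1., 1.]])\n",
    "y = np.array([1, 1, -1])\n",
    "# With eta = 1, alpha_i counts how often point i triggered an update\n",
    "w, b = dual_to_primal(np.array([2., 0., 5.]), X, y)\n",
    "print(w, b)  # [1. 1.] -3.0\n",
    "```"
   ]
  },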
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
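    "   The training instances appear only through the inner products $x_j \\cdot x_i$. These can be precomputed once and stored in the Gram matrix $G = [x_i \\cdot x_j]_{N\\times N}$, so that\n",
    "   $$\n",
    "        \\sum_{j=1}^N \\alpha_j y_j x_j \\cdot x_i = \\sum_{j=1}^N \\alpha_j y_j G[j,i]\n",
    "   $$"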
   ]
  },
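  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This sketch shows how the Gram matrix is used; the variable names are illustrative. It builds $G$ with a single matrix product. It then checks that the dual scores $\\sum_j \\alpha_j y_j G[j,i] + b$ coincide with the primal scores $w\\cdot x_i + b$. The values $\\alpha = (2, 0, 5)^T$ and $b = -3$ are taken from the dual-form output below.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "X = np.array([[3., 3.], [4., 3.], [1., 1.]])\n",
    "y = np.array([1, 1, -1])\n",
    "alpha = np.array([2., 0., 5.])\n",
    "b = -3.0\n",
    "\n",
    "# Gram matrix G[i, j] = x_i . x_j, computed once before training\n",
    "G = X @ X.T\n",
    "\n",
    "# Dual scores: sum_j alpha_j y_j G[j, i] + b for every i\n",
    "dual_scores = (alpha * y) @ G + b\n",
    "\n",
    "# Primal scores through w = sum_j alpha_j y_j x_j\n",
    "w = (alpha * y) @ X\n",
    "primal_scores = X @ w + b\n",
    "\n",
    "print(G)\n",
    "print(dual_scores, primal_scores)  # both are 3, 4, -1\n",
    "```\n",
    "\n",
    "Computing $G$ once costs one $N \\times N$ matrix product. After that, each dual score needs only a row of $G$ rather than fresh inner products."
   ]
  },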
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1 : alpha is [1. 0. 0.],b is 1\n",
      "epoch 2 : alpha is [1. 0. 1.],b is 0\n",
      "epoch 3 : alpha is [1. 0. 2.],b is -1\n",
      "epoch 4 : alpha is [1. 0. 3.],b is -2\n",
      "epoch 5 : alpha is [2. 0. 3.],b is -1\n",
      "epoch 6 : alpha is [2. 0. 4.],b is -2\n",
      "epoch 7 : alpha is [2. 0. 5.],b is -3\n",
      "\n",
      "finally get alpha is [2. 0. 5.], b is -3\n",
      "\n",
      "finally get w is [1. 1.], b is -3\n"
     ]
    },
    {
     "data": {
      "text/plain": "<Figure size 256x256 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOYAAADhCAYAAADcb8kDAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAAnYAAAJ2AHHoLmtAAAZxklEQVR4nO3de1zUBb7/8df3OzOAXEREEfCSJqamiaLAkKYs2q5RrVanEkZTFNj8VWuWW25uWT87PX57zu+37Wm7Ct4dsYvaTersnvxVp00uioiXvKSZJGqignJRmJnv+UO3zfLKDPP9zszn+R/gfL/vR+xrme8wMyiapmkIIQxF1XuAEOLnJEwhDEjCFMKAJEwhDEjCFMKAzN46Uf/+/enbt6+3TieET9i3bx+7d+/+2ee9Fmbfvn0pLi721umE8AmZmZkX/bzclRXCgCRMIQxIwhTCgCRMIQzI7TD37NmDxWKhpKTEE3uEEHggzAULFjBmzBhPbBFCnOdWmGVlZcTGxtKjRw+3h6wo+ZbqE01uH0cIf+BWmM8//zxz58695NftdjuZmZlkZmZSU1NzyX93ptWJveRbsgtLOFTX7M4kIfxCm8Ncv349I0aMIDo6+pL/xmazUVxcTHFxMfHx8Zf8dyEWE/bcVDpYTGQXlHC4XuIUga3NYVZWVvLpp58yfvx4/va3vzF79mwOHz7c5iHR4cHYc61YTCrZBaUcPXWmzccSwte1Ocx58+axYcMGPv74Y2699VZefPFF4uLi3BrTNSKYVXmpqApkFZTwvcQpApRHfo+5dOlSrFarJw5FTEQIRXlW0CC7sJRjp8965LhC+BJDPsEgpmMIq/KsOJwubIUl1DZInCKwGDJMgNjIEIryrZxpdTG5sJQTjS16TxLCawwbJkBcZAeK8q00nHVgKyzlpMQpAoShwwTo3qkDRXlWTjW3MnlRKXVNEqfwf4YPE6Bn51CK8qycaGxhyqIy6ptb9Z4kRLvyiTABekWfi/PY6bM8sKiUU2ckTuG/fCZMgN5dwijKt3K4/gxTF5dxWuIUfsqnwgTo0yWMVXlWqk80M21JOQ1nHXpPEsLjfC5MgISYcIryUvn2eCM5S8polDiFn/HJMAH6dYvAnmtl37FGcpaW09QicQr/4bNhAvSPjcCem8reo6eZsXQTzS1OvScJ4RE+HSbAwLiOrMxNZefhU+Qt38SZVolT+D6fDxNgUHwk9txUqr6rI3/FZolT+Dy/CBNgcPdIVuamsuXgSWau3MxZh8QpfJffhAkwpEcnlk9PYdOBkzxkr6DF4dJ7khBt4ldhAgzrFcXS6Sls3Hech1ZV0OqUOIXv8bswAYZfdy7Ov39dyyOrtkicwuf4ZZgAyb07s2RaMp/tOcajqytxSJzCh/htmACp10ezeFoyn+w6yuy3tkqcwmf4dZgAaX2jWTQ1mb/uOMLjb2/F6dL0niTEFfl9mAAjE7pQOHUEH20/wu8kTuEDAiJMgFv6dWXhlOF8WHWYJ9dU4ZI4hYEFTJgA6f1jeH1KEu9VHuKpddskTmFYARUmQMaAbrxqG86aiu/4w3vbJU5hSAEXJsCtN3bj5ewk3iqvZv77O9A0iVMYS0CGCfCrQbH8JWsYq8oO8twHOyVOYShmvQfo6bab4vgPTWPW6kpUReHpOwaiKIres4QI7DAB7hgSj9OlMfvNSkwqPJUpcQr9uRXmgQMHyM7OxmKx4HA4eO211xgyZIintnnNhKHd0TSY/VYlJlXlyfH9JU4fomkOaNkEtEJQCooSrPMeF7RuBq0RLCkoaug1H8OtMHv06MEXX3yBqqps2LCBF154gdWrV7tzSN1MHNYdp0tjzjtbMakw55cSpy/QHAfQTkwBrRnQADNELUQJStRnj/MI2vFs0E6d36OgdXoJNfjmazqOW2Gazf+8+alTp0hM1Oc/hqfcM7wHTk3jiXeqMKkqj916g96TxBVodY+C6+hPPvcwdP0M
RfH+Y5ta/VxwHeJclOfVPYoW8wWKEnTVx3H7GrOyspKZM2dSXV3N2rVr3T2c7u4b0ROnS+P3a7dhUhRmjeun9yRxCZp2BpzfXuQLZ8B5EMy9vb6J1kouiPLcIGj9Cq7hp7jbYQ4dOpSNGzdSUVHBgw8+SFlZ2Q9fs9vt2O12AGpqatw9lddkpfTC6dL4w7vbMZsUHvpFgt6TxEVZgItdbjhB7ejtMecFA00/+ZwGatQ1HcWtMM+ePUtw8LkL7cjISEJDL7zItdls2Gw2ADIzM905lddNtl6HS9N45r0dmFSFB8f01XuS+AlFMaF1uBeaV5+/xgQIhqA0FLWzPqPCcqDxjXMP/ABgAfMNKOZe13QYt8L8+9//zrPPPovJZELTNP70pz+5czjDeSCtN06XxnMf7MSkKOSNvl7vSeInlIgn0dRO0GQHnBAyASXiMf32hP0GTekAjYvP3aUOGY8S8cQ1H8etMDMyMsjIyHDnEIaXM7IPTpfG8+u/QlUVZozqo/ck8SOKoqKEz4TwmXpPAUBRFJSwqRA21a3jBPwTDK5G7i3X49I0Fny4E5MC00ZKnKJ9SZhXKX90XxwujWc/2IlJVZiS1lvvScKPSZjX4H+lJ+ByaTz93g5UVcGWep3ek4SfkjCv0cMZ/XC4NOat245JUZiUcm2PtglxNSTMNnh03A24XBq/X7cNVVW4b0RPvScJPyNhttHsW2/A4dJ4ck0VJkXhnuE99J4k/IiE2UaKovC7X/X/0RPfFSYO6673LOEnJEw3KIrC3NsG4HRpPPZWJSZV4c7EeL1nCT8gYbpJURTm3T4Qp6bx6Jvn3gnh9iFxes8SPk7C9ABFUXjmjhtxujR+u3oLJhXGD5Y4RdsF7JtxeZqiKDz360FMSu7Jw6u28NcdR/SeJHyYhOlBiqKwYMJg7h3Rg4dWVfDJV0evfCMhLkLC9DBVVfjXiTdx17DuzFxZwf/f9b3ek4QPkjDbgaoq/J+7h3BnYjy/WbmZz/Yc03uS8DESZjtRVYV/+5ch3H5THPnLN/HF3lq9JwkfImG2I5Oq8H/vTeRXg2KZsaycL7+WOMXVkTDbmUlV+NN9iYy7sRszlm2iZP9xvScJHyBheoHZpPLn+4eS3r8r05eWU/bNCb0nCYOTML3EYlJ5KWsYoxK6kLOkjM3fSpzi0iRML7KYVF7OTiKtbzRTF5dTcfCk3pOEQUmYXhZkVnnFlkRKn85MXVRGZXWd3pOEAUmYOgg2m3jVlsTw3lFMWVTKtu/q9Z4kDEbC1EmIxcTrk4cztGcnJi8qZfshiVP8k4SpoxCLiYIHRjC4e0cmLyplZ80pvScJg5AwdRZiMVH4QDIDYztiKyxh1xGJU0iYhtAhyMSiaSPo1y0CW0Epe46e1nuS0JmEaRChQWaWTEvm+q5hZBeU8PX3EmcgkzANJCzYzJKcFHp1DiWroJR9xxr0niR0ImEaTHiwmWXTU+jeqQNZC0v4prbxyjcSfsetMHfs2MGoUaMYPXo0GRkZ7N+/31O7AlpEiIXlM1KIiwwha2EJ3x6XOAONW2F27dqV9evX8/nnn/PEE0+wYMECT+0KeB1DLCyfnkqXiCCyFpZw8PhP/0qx8GduhRkTE0NkZCQAFosFs1nedM+TIkMtrJyRSqfQILIKSqg+IXEGCo9cYzY3NzN//nwee+zCv+Rrt9vJzMwkMzOTmpoaT5wq4HQKDcKem0pEiJmsghIO1TVf+UbC57kdpsPhYNKkScyZM4eBAwde8DWbzUZxcTHFxcXEx8s7lLdVVNi5OMOCzGQtLOFwvcTp79wKU9M0cnJyGD9+PBMnTvTQJHEx0eHB2PNSCTarZC0s4Uj9Gb0niXbkVpjr169nzZo1vPnmm6Snp/Poo496aJa4mC7hwazKs2JSFbILSvj+lMTpr9x6tOaOO+6gqUkekPCmrhHBFOVZmbSwhKyCEoryrcREhOg9S3iYPMHAB8V0DGFVnhWnS8NWUEptw1m9JwkP
kzB9VGxkCEX5Vs46XNgKSjkucfoVCdOHxUV2oCjfSmOLA1thKScbW/SeJDxEwvRx3Tt1oCjPyukz5+Ksa5I4/YGE6Qd6dg6lKM9KXVMLkxeVUt/Uqvck4SYJ00/0ig5lVZ6V2tMtPLC4lPpmidOXSZh+pHeXMIryrRyuP8PUxWWcPiNx+ioJ08/0OR/nobpmpi4uo+GsQ+9Jog0kTD/Ut2s4RXmpHDzRRM6SMholTp8jYfqphJgIVuVZ2X+skZyl5TS1SJy+RML0Yzd0i8Cel8reo6eZvrSc5han3pPEVZIw/dyA2I7Yc63sOnKa3OXlnGmVOH2BhBkAbozvyMoZqWz7rp685ZskTh8gYQaIwd0jsedaqayu4zcrNnPWIXEamYQZQG7qEcmKGalUfHuSmSsrJE4DkzADzNCenVg2I4Wyb07wkH0LLQ6X3pPERUiYASipVxTLpiezcV8tjxRV0OqUOI1GwgxQw6/rzJKcFP57by2zVm+ROA1GwgxgKX06s3haMht2fc/sNytxSJyGIWEGOOv10Syemszfdh7l8be34nRpek8SSJgCuDmhC4umJvPx9iP8TuI0BAlTADCqXxcKHhjBh9sO8+SaKlwSp64kTPGD0Td05Y0pw3m/soa5ayVOPUmY4gK/6B/Da5OTWLflEPPe3S5x6kTCFD8zdmA3XslO4u1N1Tzz/nY0TeL0NglTXNQvB8XycnYSRWXVPPv+DonTyyRMcUnjB8fyl6xhrCw9yP/+cKfE6UXyl2bFZWXeFIfTpTFr9RZMisK82weiKIres/yehOkDNE2jZt8RgkKC6Noj2uvnvzMxHpemMfvNSkwmhbnjB0ic7cytMJuamhg7dixfffUVr7/+OpMmTfLULnHed3tqmHf7C5w60YDm0rhuUE8WvP8kHTtHeHXHhKHdcbo0Hn97K2ZVYc4v+0uc7cita8zg4GDWrVsnfxeznWiaxu9v+1dq9h2l4WQjjfVN7C7by7/nvKLLnruTevBv9wzh1U/38eJ/7dVlQ6Bw6yemyWQiNjbWU1vETxzcdYjGusYLPud0uKjcsB2Xy4Wqev+xu3tH9MSlaTy5ZhsmRWHWuH5e3xAI2vUa0263Y7fbAaipqWnPU/mloGALF3sgVFVVXe9G3p/cC6cLnlq3DbNJ4aFfJOi2xV+1a5g2mw2bzQZAZmZme57KL8Vd343YPjHsr/oW1/mXZAWFWBh9r1X367vs1F44XS6efm8HqqIwM72vrnv8jfwe0+Be+GgeQzMGExIeQoeIENLvH8kjL+fqPQuAKWm9efbOG/njx7tY+Pk+vef4Fbd/Yk6cOJGqqirCwsL48ssveemllzyxS5wXFRPJH//zaRytDhRVwWQy6T3pAtNG9sGpwYIPd6IqCrm3XK/3JL/gdpjvvvuuB2aIKzFbjPsr5xmj+uByaTy//ivMqsK0kX30nuTzjPvdFj4lb/T1OFwaz36wE5OqMCWtt96TfJqEKTxmZnpfXJp27gEhVcGWep3ek3yWhCk86qFfJOBwasxbtx2TojAppZfek3yShCk8bta4fjg1jd+v24aqKtw3oqfek3yOhCnaxexx/XC6XDy5pgqTonDP8B56T/IpEqZoF4py7onuThfMeWcrJlVh4rDues/yGRKmaDeKovDk+P44XS4ee6sSVVX4dWK83rN8goQp2pWiKDyVORCni3Ov51QUbh8Sp/csw5MwRbtTFIWn7xiIS9P47eotqArcdpPEeTkSpvAKRVGYf+eNOF0ajxRt4RVV4VeD5CWDlyJPYhdeoygKz/16EPcl9+ThVRX8186jek8yLAlTeJWqKjw/YTD3JPVgpn0zG3ZJnBcjYQqvU1WFF+66iQlDu/Pgigo+3f293pMMR8IUulBVhT/eM4Q7hsSRv2Iz/733mN6TDEXCFLoxqQr/fm8itw2OJXfZJr78ulbvSYYhYQpdmVSF/3dvIrfe2I3py8op2X9c70mGIGEK3ZlNKn++fygZA2LI
WVJO2Tcn9J6kOwlTGILZpPIfk4Yx+oYuTFtSxqYDgR2nhCkMw2JS+UtWEiMTujBtSTkVB0/qPUk3EqYwlCCzyivZSaT26czURWVUVtfpPUkXEqYwnCCzyquTkxjeO4opi0qp+q5O70leJ2EKQwo2m3h98nCG9uzE5MJSth+q13uSV0mYwrBCLCYKHhjBkB6dsBWWsqMmcOKUMIWh/SPOQfEdsRWW8tXhU3pP8goJUxhehyAThVNH0L9bBLbCUnYfOa33pHYnYQqfEBpkZvG0ZBK6hpNdUMLeo/4dp4QpfEZYsJnFOcn07hJGVkEpX3/foPekdiNhCp8SHmxmaU4yPTt3ILughP3H/DNOt8MsLCwkLS2NUaNGUVVV5YlNQlxWRIiFZdNTiIsMIaughAO1jVe+kY9xK8wTJ07w2muv8fnnn7N06VJmz57tqV1CXFbHEAvLZ6QSE3EuzoPHm/Se5FFuhVlaWkp6ejoWi4WEhARqa2txuVye2ibEZUV2sLBiRgqdw4LIKiih+oT/xOlWmCdPniQqKuqHj8PDw6mvD5xfAgv9dQoNYuWMVDp2sJBVUMKhuma9J3mEW2FGRUVRV1f3w8cNDQ1ERkb+8LHdbiczM5PMzExqamrcOZUQlxQVFsTKGSmEBZnJWlhCjR/E6VaYqampfPbZZzgcDr755huio6NR1X8e0mazUVxcTHFxMfHx8tb4ov1Ehwdjz0sl2KySVVDCkfozek9yi1thdu7cmby8PG655RamTJnCiy++6KldQlyzLuHBrMqzYlYVsgpKOHrKd+N0+9cl+fn5bNy4kS+++ILExERPbBKizbpGBFOUZ0UBsgpK+P60b8YpTzAQfiemYwir8qy4XBrZBaXUNpzVe9I1kzCFX4qNDKEo30qr00V2QQnHfSxOCVP4rbjIDhTlWWludWIrLOVEY4vek66ahCn8Wnync3GePuNgcmEpdU2+EaeEKfxej6hQVudbqWtqwVZYSn1Tq96TrkjCFAGhZ+dQivKtnGhsYfKiUuqbjR2nhCkCxnXRYRTlWfn+9BkeWFzGqTPGjVPCFAGld5cwVuVZqalrZuriMk4bNE4JUwScvl3DKcqzUn2imZwl5TScdeg96WckTBGQEmLCKcpL5ZvaRqYvKaepxVhxSpgiYPXrFsGqPCtfH2tg+tJymlucek/6gYQpAlr/2AhWzkhl15HTzFhmnDglTBHwbozviD03lR01p8hfsYkzrfrHKWEKAQyKj8Sem8rW6jryV2zWPU4JU4jzBnePZMWMVLYcPMnMlZs569AvTglTiB9J7NmJ5dNTKD9wkofsFbQ49HlzOQlTiJ8Y1iuKZdOT2bjvOA+vqqDV6f04JUwhLmL4dZ1ZOj2FL76u5bdFW7wep4QpxCUk9+7MkmnJfLr7GI+ursThxTglTCEuI/X6aBZPS+aTXUeZ/dZWr8UpYQpxBWl9o1k0NZm/7jjCnLe34nRp7X5OCVOIqzAyoQsFD4ygePsRfvdO+8cpYQpxlUbf0JWFU4bz4dbDzF1Thasd45QwhbgG6f1jeH1KEu9WHuKpddvaLU4JU4hrlDGgG6/ahvPO5u/4w3vb0TTPxylhCtEGt97YjZezk3irvJr57+/weJwSphBtNH5wLH/JGoa99CDPfbDTo3GaPXYkIQLQbTfF8WeXxqzVWzCpCn+4fSCKorh9XAlTCDfdmRiPS9N4/K2t3DEkjmG9oq58oyto813ZN954g379+jFgwAC3Rwjh6yYM7c6Gx9M9EiW4EeZdd93Fzp07PTJCCH/QKzrUY8dq813ZmJgYj40QQlyoXa8x7XY7drsdgJqamvY8lRB+5bJhNjQ0MG7cuJ99Pjc3l9zc3Cse3GazYbPZAMjMzGzjRCECz2XDDA8Pp6SkxFtbhBDntfnBn7Vr1zJu3Diqq6sZN24cmzZt8uQuIQJam68x7777bu6++25PbhFCnOe1Jxjs27fviteZNTU1xMfHe2nRlcmey5M9l3c1
e/bt23fxL2gGctttt+k94QKy5/Jkz+W5s0eexC6EARkqzH/8asUoZM/lyZ7Lc2ePomnt8CpPIYRbDPUTUwhxjoQphAEZLkyjvJyssLCQtLQ0Ro0aRVVVla5bmpqaSEtLo1OnTqxevVrXLQA7duxg1KhRjB49moyMDPbv36/rngMHDnDzzTczZswYRo4cqfv36x/27NmDxWJp27PnPPbYsIccPXpUa2lp0fr376/bhuPHj2tJSUlaS0uLtnfvXi0jI0O3LZqmaQ6HQzt8+LA2f/58raioSNctmnbue1RXV6dpmqZ99NFH2rRp03Td09raqjmdTk3TNO2TTz7R7r//fl33/MPkyZO1sWPHahs3brzm2xruHQyM8HKy0tJS0tPTsVgsJCQkUFtbi8vlQlX1uYNhMpmIjY3V5dwX8+PvkcViwWzW939GPz7/qVOnSExM1HHNOWVlZcTGxmIymdp0e8PdlTWCkydPEhX1z1eih4eHU19fr+MiY2pubmb+/Pk89thjek+hsrKStLQ0Hn74YcaOHav3HJ5//nnmzp3b5tvr8n917r6crL1FRUVRV1f3w8cNDQ1ERkbqN8iAHA4HkyZNYs6cOQwcOFDvOQwdOpSNGzdSUVHBgw8+SFlZmW5b1q9fz4gRI4iOjm7zMXQJ0+gvJ0tNTeWZZ57B4XBQXV1NdHS0bndjjUjTNHJychg/fjwTJ07Uew5nz54lODgYgMjISEJDPfcWH21RWVnJp59+ypdffsm2bdvYvXs3a9euJS4u7uoP4vlLXvesWbNGGzt2rBYaGqqNHTtWKy8v12XHG2+8oVmtVm3kyJFaZWWlLht+bMKECVqfPn20wYMHa4888oiuWz744AOtQ4cO2pgxY7QxY8Zos2bN0nXPJ598ot1yyy1aenq6NmbMGG3z5s267vmxqVOntunBH3nmjxAGJPfPhDAgCVMIA5IwhTAgCVMIA5IwhTAgCVMIA5IwhTCg/wFyijiznjuILAAAAABJRU5ErkJggg==\n"
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
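    "# Gram matrix of the training set: Gram[i][j] = dot(x_i, x_j)\n",
    "Gram = np.dot(x_train, x_train.T)\n",
    "\n",
    "# Hyperparameter and initial parameter values\n",
    "lr = 1  # learning rate eta\n",
    "alpha = np.zeros(x_train.shape[0])  # alpha[i] = n_i * eta, where n_i is the number of updates triggered by sample i\n",
    "b = 0\n",
    "epoch = 1  # counts individual updates (one per misclassified sample), not full passes over the data\n",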
    "\n",
    "flag = True\n",
    "while flag:\n",
    "    label = np.zeros_like(y_train)\n",
    "    for i in range(len(x_train)):\n",
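    "        # Dual-form score: sum_j alpha_j * y_j * Gram[j][i] + b\n",
    "        score = 0\n",
    "        for j in range(len(x_train)):\n",
    "            score += alpha[j] * y_train[j] * Gram[j][i]\n",
    "        y_hat = score + b\n",
    "        # Misclassified sample: alpha_i <- alpha_i + eta, b <- b + eta * y_i\n",
    "        if sign(y_hat) != y_train[i]:\n",
    "            alpha[i] += lr\n",
    "            b += lr * y_train[i]\n",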
    "            print(\"epoch {}: alpha is {},b is {}\".format(epoch,alpha,b))\n",
    "            epoch += 1\n",
    "        label[i] = sign(y_hat)\n",
    "    # Stop once a full pass over the training set classifies every sample correctly\n",
    "    if (label == y_train).all():\n",
    "        flag = False\n",
    "\n",
    "print(\"\\nfinally get alpha is {}, b is {}\".format(alpha,b))\n",
    "\n",
    "# Recover the primal weights: w = sum_i alpha_i * y_i * x_i\n",
    "w = np.zeros(x_train.shape[1])\n",
    "for i in range(len(alpha)):\n",
    "    w += alpha[i] * y_train[i] * x_train[i]\n",
    "\n",
    "print(\"\\nfinally get w is {}, b is {}\".format(w,b))\n",
    "draw_fig(x_train,y_train,w,b)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4 (default, Aug  9 2019, 18:34:13) [MSC v.1915 64 bit (AMD64)]"
  },
  "vscode": {
   "interpreter": {
    "hash": "3c874ee3d65450ef899cacaa06bdca40f9e4450c5e631a94911db3215dbda5b8"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
